# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Union

from tensorrt_llm.llmapi.disagg_utils import ServerRole

COUNTER_METRICS = [
    ("total_requests", "Total number of requests"),
    ("error_requests", "Total number of error requests"),
    ("retry_requests", "Total number of retry requests"),
    ("completed_requests", "Total number of completed requests"),
]
# fmt: off
LONG_TIME_BUCKETS = [
    0.1, 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0,
     60.0, 120.0, 240.0, 480.0, 960.0, 1920.0,
]
SHORT_TIME_BUCKETS = [
    0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0,
    7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, 2560.0,
]
# fmt: on
HISTOGRAM_METRICS = [
    (
        "first_token_latency_seconds",
        "Histogram of latency from first token to completion in seconds",
        SHORT_TIME_BUCKETS,
    ),
    (
        "complete_latency_seconds",
        "Histogram of latency from request arrival to last token in seconds",
        LONG_TIME_BUCKETS,
    ),
    (
        "per_token_latency_seconds",
        "Histogram of latency from request arrival to completion in seconds",
        SHORT_TIME_BUCKETS,
    ),
]

MetricsTypeLiteral = Literal["counter", "histogram"]


@dataclass
class MetricsDefinition:
    name: str
    description: str
    type: MetricsTypeLiteral
    buckets: Optional[List[float]] = None


METRICS_DEFINITIONS = [
    MetricsDefinition("total_requests", "Total number of requests", "counter"),
    MetricsDefinition("error_requests", "Total number of error requests", "counter"),
    MetricsDefinition("retry_requests", "Total number of retry requests", "counter"),
    MetricsDefinition("completed_requests", "Total number of completed requests", "counter"),
    MetricsDefinition(
        "first_token_latency_seconds",
        "Histogram of latency from first token to completion in seconds",
        "histogram",
        SHORT_TIME_BUCKETS,
    ),
    MetricsDefinition(
        "complete_latency_seconds",
        "Histogram of latency from request arrival to last token in seconds",
        "histogram",
        LONG_TIME_BUCKETS,
    ),
    MetricsDefinition(
        "per_token_latency_seconds",
        "Histogram of latency from request arrival to completion in seconds",
        "histogram",
        SHORT_TIME_BUCKETS,
    ),
]

ROLE_TO_CLIENT_TYPE = {
    ServerRole.CONTEXT: "ctx",
    ServerRole.GENERATION: "gen",
    ServerRole.MM_ENCODER: "mme",
}


class ClientMetricsCollector:
    def __init__(self, role: ServerRole):
        self._role = role
        # import lazily to avoid breaking `set_prometheus_multiproc_dir`
        from prometheus_client import Counter, Histogram

        def instance_metric(definition: MetricsDefinition) -> Union[Counter | Histogram]:
            name = f"{ROLE_TO_CLIENT_TYPE[role]}_{definition.name}"
            if definition.type == "counter":
                return Counter(name, definition.description)
            elif definition.type == "histogram":
                return Histogram(name, definition.description, buckets=definition.buckets)
            else:
                raise ValueError(f"Invalid metric type: {definition.type}")

        self._metrics = {
            definition.name: instance_metric(definition) for definition in METRICS_DEFINITIONS
        }

    def __getattr__(
        self, key: str
    ):  # no return type hint to not import prometheus_client at module level
        return self._metrics[key]


class DisaggPerfMetricsCollector:
    def __init__(self, max_requests: int):
        self._max_requests = max_requests
        self._request_meteics = deque(maxlen=max_requests)
        self._server_metrics = defaultdict(dict)
        self._lock = asyncio.Lock()
        self._clients = []

    def add_client(self, client):
        self._clients.append(client)

    async def add_per_request_metrics(
        self,
        ctx_server: str,
        gen_server: str,
        ctx_request_id: int,
        server_arrival_time: float,
        server_first_token_time: float,
    ):
        async with self._lock:
            self._request_meteics.append(
                (
                    ctx_server,
                    gen_server,
                    ctx_request_id,
                    server_arrival_time,
                    server_first_token_time,
                )
            )

    async def get_perf_metrics(self) -> List[Dict[str, Any]]:
        perf_metrics = {}
        for client in self._clients:
            metrics_dict = await client.collect_metrics()
            perf_metrics.update(metrics_dict)

        return_metrics = []
        async with self._lock:
            for server, metrics_data in perf_metrics.items():
                server_metrics = self._server_metrics[server]
                # avoid metrics map inflation by limiting the number of requests to add
                available_req_num = min(
                    max(0, self._max_requests - len(server_metrics)), len(metrics_data)
                )
                req_metrics_map = {
                    req_metrics["ctx_request_id"]: req_metrics
                    for req_metrics in metrics_data[:available_req_num]
                    if "ctx_request_id" in req_metrics
                }
                server_metrics.update(req_metrics_map)

            remain_keys = []
            for (
                ctx_server,
                gen_server,
                ctx_request_id,
                server_arrival_time,
                server_first_token_time,
            ) in self._request_meteics:
                gen_perf_metrics = self._server_metrics[gen_server].pop(ctx_request_id, None)
                if gen_perf_metrics is None:
                    # generation not finished
                    remain_keys.append(
                        (
                            ctx_server,
                            gen_server,
                            ctx_request_id,
                            server_arrival_time,
                            server_first_token_time,
                        )
                    )
                    continue
                ctx_perf_metrics = self._server_metrics[ctx_server].pop(ctx_request_id, None)
                # TODO: strip the keys for less repeating and use table style response
                return_metrics.append(
                    {
                        "ctx_server": ctx_server,
                        "gen_server": gen_server,
                        "disagg_server_arrival_time": server_arrival_time,
                        "disagg_server_first_token_time": server_first_token_time,
                        "ctx_perf_metrics": ctx_perf_metrics,
                        "gen_perf_metrics": gen_perf_metrics,
                    }
                )
            self._request_meteics = deque(remain_keys, maxlen=self._max_requests)
        return return_metrics
